from pathlib import Path
import random

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import torchvision.transforms as transforms

from examples.SUBTRACTION.data.extended_imagefolder import RegressionImageFolder, JointImageFolder


# Side length in pixels of the square input images; image batches are reshaped
# to [batch_size, IMSIZE, IMSIZE, 1] when building the dataloader.
IMSIZE = 128

def division(coords, digits):
    """Divide one digit by the other, choosing the operand order from the coordinates.

    Returns ``digits[0] / digits[1]`` when ``coords[0] >= coords[1]`` and
    ``digits[1] / digits[0]`` otherwise.

    If the chosen divisor is zero, the digits and then the coordinates are
    printed (for debugging) and ``None`` is returned instead of raising.
    """
    first_coord, second_coord = coords
    first_digit, second_digit = digits
    try:
        if first_coord >= second_coord:
            numerator, denominator = first_digit, second_digit
        else:
            numerator, denominator = second_digit, first_digit
        return numerator / denominator
    except ZeroDivisionError:
        print(first_digit, second_digit)
        print(first_coord, second_coord)
        return None

def subtraction(coords, digits):
    """Subtract the two digits, choosing the operand order from the coordinates.

    Returns ``digits[1] - digits[0]`` when ``coords[0] >= coords[1]`` and
    ``digits[0] - digits[1]`` otherwise, i.e. the digit at the smaller
    coordinate minus the digit at the larger one (ties use the first form).

    NOTE(review): ``batch_subtraction`` in this module computes the digit at the
    larger coordinate minus the other one, which is the opposite sign for
    distinct coordinates. Confirm which convention is intended.
    """
    first_coord, second_coord = coords
    first_digit, second_digit = digits
    return second_digit - first_digit if first_coord >= second_coord else first_digit - second_digit

def batch_subtraction(coords, digits):
    """Per-sample subtraction over a batch, ordering operands by coordinate.

    ``coords`` and ``digits`` are each stacked along a new leading axis, so
    after stacking ``[k, i]`` refers to operand ``k`` (0 or 1) of sample ``i``.
    For each sample, the digit whose coordinate is selected by ``tf.argmax``
    (the larger one) is the minuend and the other digit is the subtrahend. On
    equal coordinates, the choice is whatever ``tf.argmax`` returns.

    NOTE(review): for distinct coordinates this is the opposite sign convention
    to ``subtraction``, which returns the digit at the smaller coordinate minus
    the other. Confirm which is intended.

    Args:
        coords: pair of per-sample coordinate tensors (presumably 1-D of
            length ``batch`` each; TODO confirm).
        digits: pair of per-sample digit tensors with the same batch size.

    Returns:
        ``digits[argmax_i, i] - digits[argmin_i, i]`` for every sample ``i``.
    """
    cs = tf.stack(coords)
    ds = tf.stack(digits)
    argmax = tf.argmax(cs, axis=0)  # int64 operand index (0 or 1) per sample
    argmin = 1 - argmax
    # Column (sample) index for every row, so all (operand, sample) pairs are
    # gathered in one vectorised op. The previous Python loop of scalar lookups
    # was slow and failed on an empty batch (tf.stack of an empty list).
    cols = tf.range(tf.shape(argmax, out_type=argmax.dtype)[0], dtype=argmax.dtype)
    dsmax = tf.gather_nd(ds, tf.stack([argmax, cols], axis=1))
    dsmin = tf.gather_nd(ds, tf.stack([argmin, cols], axis=1))
    return tf.subtract(dsmax, dsmin)

def batch_joint_subtraction(digits):
    """Return column 0 minus column 1 of ``digits`` as a TensorFlow tensor.

    ``digits`` is indexed with ``[:, k]``, so it needs at least two axes; each
    row holds the minuend in column 0 and the subtrahend in column 1.
    """
    minuend = digits[:, 0]
    subtrahend = digits[:, 1]
    return tf.subtract(minuend, subtrahend)


def load_dataset(batch_size, samedistr=False, data_root='examples/SUBTRACTION/data'):
    """Build the train/val/test ``JointImageFolder`` datasets.

    Every image is converted to a single-channel grayscale tensor and
    normalised with mean 0.5 and std 0.5.

    Args:
        batch_size: forwarded unchanged to every ``JointImageFolder``.
        samedistr: if true, read from the ``multimnist_smaller_aligned_xy_samedistr``
            folders instead of ``multimnist_smaller_aligned_xy``.
        data_root: directory that contains the dataset folders. The default is
            a relative path, so it is resolved against the current working
            directory (not this file's location), matching the previous
            hard-coded roots exactly.

    Returns:
        dict mapping ``"train"``, ``"val"`` and ``"test"`` to their datasets.
    """
    transform = transforms.Compose(
        [
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]
    )

    # The two dataset variants differ only in the folder name.
    folder = 'multimnist_smaller_aligned_xy_samedistr' if samedistr else 'multimnist_smaller_aligned_xy'
    return {
        split: JointImageFolder(
            root=f'{data_root}/{folder}/{split}',
            transform=transform, batch_size=batch_size
        )
        for split in ('train', 'val', 'test')
    }


def create_joint_dataloader(dataset_name, seed=7, curriculum=False, batch_size=10, data_size=1000,
                            generative=False, samedistr=False, regression=True, constraints=False):
    """Materialise one split of the joint dataset as a list of TF batches.

    Each dataset item is treated as one batch. Item field 0 (images) is
    reshaped to ``[batch_size, IMSIZE, IMSIZE, 1]``. Field 1 (``ds``, digits)
    and, when ``regression`` is set, field 2 (``diss``, coordinate-like values
    shifted by ``14/128``) provide the targets. Every output element is
    ``[inputs, weights]``, where ``weights`` is a float32 vector of ones.

    Args:
        dataset_name: key into ``load_dataset``'s result (``"train"``, ``"val"``
            or ``"test"``).
        seed: seed for shuffling the item order; ``None`` keeps dataset order.
        curriculum: select the curriculum output layouts (see below).
        batch_size: samples per dataset item; also forwarded to ``load_dataset``.
        data_size: ``data_size // batch_size`` items are used (fewer if the
            split is smaller).
        generative: only used when ``curriculum`` is true.
        samedistr: forwarded to ``load_dataset``.
        regression: include values derived from item field 2.
        constraints: only used when ``curriculum`` is false and ``regression``
            is true.

    ``inputs`` layout per flag combination (``dK``/``cK`` = column K of
    ``ds``/``diss``):
        curriculum=False, regression=False:   [images, float32(d0 - d1)]
        curriculum=False, regression, constraints:
            [images, d0 - d1, sqrt((c0 - c1)**2 + (c2 - c3)**2)]
        curriculum=False, regression, not constraints: [images, d0, d1, c0, c1]
        curriculum=True, generative:
            [images, tf.minimum(images, 1), tf.minimum(images, 1)]
        curriculum=True, regression:          [images, d0, d1, c0, c1, c2, c3]
        curriculum=True, regression=False:    [images, d0, d1]

    Returns:
        list of ``[inputs, weights]`` pairs.
    """
    datasets = load_dataset(batch_size, samedistr=samedistr)
    dataset = datasets[dataset_name]
    image_indices = list(range(len(dataset)))

    if seed is not None:
        rng = random.Random(seed)
        rng.shuffle(image_indices)

    left_multiplier = np.ones([batch_size, IMSIZE, IMSIZE, 1])
    right_multiplier = np.ones([batch_size, IMSIZE, IMSIZE, 1])
    # Added to every field-2 value. 14 px out of 128; presumably moves a
    # 28x28 digit's corner to its centre in normalised coordinates — confirm.
    coord_offset = 14 / 128

    primordial_dl = []
    num_batches = data_size // batch_size
    for index in image_indices[:num_batches]:
        # Fetch the item exactly once. Indexing the dataset may load from disk
        # and may not be deterministic, so the image, digit and coordinate
        # fields must all come from the same fetch.
        item = dataset[index]
        images = tf.reshape(tf.constant(item[0].numpy()), [batch_size, IMSIZE, IMSIZE, 1])

        if curriculum and generative:
            primordial_dl.append(
                [[images, tf.minimum(images, left_multiplier), tf.minimum(images, right_multiplier)],
                 tf.constant([1.] * batch_size, dtype=tf.float32)]
            )
            continue

        ds = tf.constant(item[1].numpy())
        weights = tf.constant([1.] * len(ds), dtype=tf.float32)
        diss = None
        if regression:
            diss = tf.constant(item[2].numpy()) + coord_offset

        if curriculum:
            if regression:
                inputs = [images, ds[:, 0], ds[:, 1], diss[:, 0], diss[:, 1], diss[:, 2], diss[:, 3]]
            else:
                inputs = [images, ds[:, 0], ds[:, 1]]
        elif not regression:
            expected_out = tf.constant(batch_joint_subtraction(ds))
            inputs = [images, tf.cast(expected_out, dtype=tf.float32)]
        elif constraints:
            expected_out = tf.constant(batch_joint_subtraction(ds))
            distances = tf.sqrt((diss[:, 0] - diss[:, 1]) ** 2 + (diss[:, 2] - diss[:, 3]) ** 2)
            inputs = [images, expected_out, distances]
        else:
            # NOTE(review): only c0/c1 are emitted here, while the curriculum
            # regression layout emits c0..c3. Confirm this is intended.
            inputs = [images, ds[:, 0], ds[:, 1], diss[:, 0], diss[:, 1]]

        primordial_dl.append([inputs, weights])

    return primordial_dl

